// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2016 Google, Inc.
 *
 * Original Code by Pavel Labath <labath@google.com>
 *
 * Code modified by Pratyush Anand <panand@redhat.com>
 * for testing different byte select for each access size.
 */

#define _GNU_SOURCE

#include <asm/ptrace.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/ptrace.h>
#include <sys/param.h>
#include <sys/uio.h>
#include <stdint.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <unistd.h>
#include <elf.h>
#include <errno.h>
#include <signal.h>
#include <limits.h>

#include "kselftest.h"

/*
 * si_code of a hardware breakpoint/watchpoint SIGTRAP; provided here for
 * C library headers that do not define it.
 */
#if !defined(TRAP_HWBKPT)
# define TRAP_HWBKPT 0x04
#endif

/*
 * Number of stores the child performs, and therefore the number of
 * watchpoint traps the parent expects. main() may override it from argv[1].
 */
static long _iteration = 4L;

/*
 * Parse a non-negative decimal iteration count from @s.
 *
 * Returns the parsed value, or -1 when @s is NULL or empty, has no digits,
 * has trailing characters after the number (e.g. "12abc"), is negative,
 * or is out of range for a long.
 */
long parse_iteration(const char *s)
{
  char *end;
  long n;

  if (!s)
    return -1;

  errno = 0;
  n = strtol(s, &end, 10);

  /* strtol() leaves end == s when no digits were consumed. */
  if (end == s || *end != '\0')
    return -1;
  /* ERANGE covers overflow; underflow (LONG_MIN) is also negative. */
  if (n < 0 || errno == ERANGE)
    return -1;
  return n;
}


/*
 * Body of the forked tracee; never returns.
 *
 * Requests tracing and stops itself with SIGSTOP so the parent can arm the
 * watchpoint. It then stores to *addr _iteration times. Exits with status 1
 * if setup fails or @addr is not aligned to sizeof(int), and 0 otherwise.
 *
 * NOTE(review): _exit() does not flush stdio, so messages printed here can
 * be lost when stdout is not line-buffered. Calling fflush() instead would
 * risk re-emitting output the parent had buffered before fork().
 */
static void child(volatile int *addr)
{
  const size_t size = sizeof(*addr);

  if (ptrace(PTRACE_TRACEME, 0, NULL, NULL) != 0) {
    ksft_print_msg(
      "ptrace(PTRACE_TRACEME) failed: %s\n",
      strerror(errno));
    _exit(1);
  }

  /* The parent installs the watchpoint while we are stopped here. */
  if (raise(SIGSTOP) != 0) {
    ksft_print_msg(
      "raise(SIGSTOP) failed: %s\n", strerror(errno));
    _exit(1);
  }

  /* No call failed here, so errno is meaningless; report the address. */
  if ((uintptr_t) addr % size) {
    ksft_print_msg(
      "Wrong address write for the given size: %p is not %zu-byte aligned\n",
      (void *)addr, size);
    _exit(1);
  }

  ksft_print_msg("Memory at %p will be written %ld times\n", addr, _iteration);
  for (long i = 0; i < _iteration; i++) {
    *addr = 47;
  }

  _exit(0);
}

/*
 * Arm hardware watchpoint slot 0 of the stopped tracee @pid so that stores
 * to the int at @addr trap.
 *
 * The slot address is @addr rounded down to 8 bytes. A byte-select mask,
 * shifted to bit 5 of the control word, picks which bytes of that
 * doubleword are watched. The watched object therefore must not cross an
 * 8-byte boundary.
 *
 * Returns true on success, or false (after logging why) otherwise.
 */
static bool set_watchpoint(pid_t pid, const volatile int *addr)
{
  const size_t size = sizeof(*addr);
  const uintptr_t uaddr = (uintptr_t)addr;
  const unsigned int offset = uaddr % 8;
  const unsigned int type = 2; /* Write */
  const unsigned int enable = 1;
  unsigned int byte_mask;
  unsigned int control;
  struct user_hwdebug_state dreg_state;
  struct iovec iov;
  int err;

  /* Otherwise the byte mask would spill past its 8 bits into other fields. */
  if (offset + size > 8) {
    ksft_print_msg("Cannot watch %zu bytes at 0x%lx: crosses an 8-byte boundary\n",
      size, (unsigned long) uaddr);
    return false;
  }

  byte_mask = ((1u << size) - 1) << offset;
  control = byte_mask << 5 | type << 3 | enable;

  memset(&dreg_state, 0, sizeof(dreg_state));
  /* Round down in bytes; "addr - offset" would step in units of int. */
  dreg_state.dbg_regs[0].addr = uaddr - offset;
  dreg_state.dbg_regs[0].ctrl = control;
  iov.iov_base = &dreg_state;
  /* Transfer only the header plus the first slot. */
  iov.iov_len = offsetof(struct user_hwdebug_state, dbg_regs) +
        sizeof(dreg_state.dbg_regs[0]);

  ksft_print_msg("Setting hardware watchpoint 0 with: byte_mask 0x%lx, "
    "addr 0x%lx, ctrl 0x%lx\n",
    (unsigned long) byte_mask,
    (unsigned long) dreg_state.dbg_regs[0].addr,
    (unsigned long) dreg_state.dbg_regs[0].ctrl);

  if (ptrace(PTRACE_SETREGSET, pid, NT_ARM_HW_WATCH, &iov) == 0)
    return true;

  /* Save errno before logging, which may clobber it. */
  err = errno;
  if (err == EIO)
    ksft_print_msg(
      "ptrace(PTRACE_SETREGSET, NT_ARM_HW_WATCH) not supported on this hardware: %s\n",
      strerror(err));
  else
    ksft_print_msg(
      "ptrace(PTRACE_SETREGSET, NT_ARM_HW_WATCH) failed: %s\n",
      strerror(err));
  return false;
}

static void print_dreg(pid_t pid) {
  struct user_hwdebug_state dreg_state;
  struct iovec iov;

  memset(&dreg_state, 0, sizeof(dreg_state));
  iov.iov_base = &dreg_state;
  iov.iov_len = offsetof(struct user_hwdebug_state, dbg_regs) +
    sizeof(dreg_state.dbg_regs[0]);

  if (ptrace(PTRACE_GETREGSET, pid, NT_ARM_HW_WATCH, &iov) == 0)
    ksft_print_msg("Value in hardware watchpoint 0: "
      "addr 0x%lx, ctrl 0x%lx\n",
      (unsigned long) dreg_state.dbg_regs[0].addr,
      (unsigned long) dreg_state.dbg_regs[0].ctrl);
  else
    ksft_print_msg("ptrace(PTRACE_GETREGSET, NT_ARM_HW_WATCH) failed: %s\n",
      strerror(errno));
}

/*
 * Fork a tracee that stores to @addr and arm a write watchpoint on @addr.
 * Then check that each of the _iteration stores is reported to this tracer
 * as a SIGTRAP stop whose siginfo carries si_code TRAP_HWBKPT. Each wait
 * for a trap is bounded by alarm(3).
 *
 * Returns true on success. This function only logs diagnostics; the caller
 * reports the test result. On failure, a tracee that is still alive is
 * killed and reaped, so no stopped child is left behind.
 */
static bool run_test(int *addr)
{
  int status = 0;
  siginfo_t siginfo;
  pid_t wpid = -1;
  long i = 0;
  pid_t pid = fork();

  if (pid < 0) {
    /* Only log: main() emits the single planned test result. */
    ksft_print_msg(
      "fork() failed: %s\n", strerror(errno));
    return false;
  }
  if (pid == 0)
    child(addr);

  wpid = waitpid(pid, &status, __WALL);
  if (wpid != pid) {
    ksft_print_msg(
      "waitpid() failed: %s\n", strerror(errno));
    goto kill_child;
  }
  if (!WIFSTOPPED(status)) {
    /* The child exited or was killed; waitpid() has already reaped it. */
    ksft_print_msg("child did not stop\n");
    return false;
  }
  if (WSTOPSIG(status) != SIGSTOP) {
    ksft_print_msg("child did not stop with SIGSTOP\n");
    goto kill_child;
  }

  if (!set_watchpoint(pid, addr))
    goto kill_child;

  print_dreg(pid);

  for (; i < _iteration; i++) {
    if (ptrace(PTRACE_CONT, pid, NULL, NULL) < 0) {
      ksft_print_msg(
        "ptrace(PTRACE_CONT) failed: %s\n",
        strerror(errno));
      goto error;
    }

    /*
     * main() installs the SIGALRM handler without SA_RESTART. A missing
     * trap therefore makes waitpid() fail with EINTR instead of hanging.
     */
    alarm(3);
    wpid = waitpid(pid, &status, __WALL);
    if (wpid != pid) {
      ksft_print_msg(
        "waitpid() failed: %s\n", strerror(errno));
      goto error;
    }
    alarm(0);
    if (WIFEXITED(status)) {
      ksft_print_msg("child exited prematurely\n");
      goto error;
    }
    if (!WIFSTOPPED(status)) {
      ksft_print_msg("child did not stop\n");
      goto error;
    }

    if (WSTOPSIG(status) != SIGTRAP) {
      ksft_print_msg("child did not stop with SIGTRAP\n");
      goto error;
    }
    if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &siginfo) != 0) {
      ksft_print_msg(
        "ptrace(PTRACE_GETSIGINFO): %s\n",
        strerror(errno));
      goto error;
    }
    if (siginfo.si_code != TRAP_HWBKPT) {
      ksft_print_msg(
        "Unexpected si_code %d\n", siginfo.si_code);
      goto error;
    }
  }

  ksft_print_msg("Watchpoint set at %p has been triggered %ld time(s)\n",
    addr, i);

  kill(pid, SIGKILL);
  wpid = waitpid(pid, &status, 0);
  if (wpid != pid) {
    ksft_print_msg(
      "waitpid() failed: %s\n", strerror(errno));
    return false;
  }
  return true;

error:
  /* Cancel an alarm that may still be pending if waitpid() failed early. */
  alarm(0);
  ksft_print_msg("Watchpoint set at %p has been triggered %ld time(s)\n",
    addr, i);
  /*
   * A successful waitpid() that returned a non-stop status means the child
   * has already exited or been killed, and it has been reaped.
   */
  if (wpid == pid && !WIFSTOPPED(status))
    return false;
kill_child:
  kill(pid, SIGKILL);
  waitpid(pid, NULL, __WALL);
  return false;
}

/*
 * SIGALRM handler that deliberately does nothing. Because a handler is
 * installed, an expiring alarm() interrupts a blocking call (the waits in
 * run_test()) instead of terminating the process.
 */
static void sigalrm(int sig)
{
  (void)sig; /* unused */
}

/*
 * Entry point. Usage: <prog> [ITERATIONS]
 *
 * An optional positive integer argument overrides _iteration, the number
 * of watched stores. Runs one watchpoint test and exits with the kselftest
 * pass/fail status.
 */
int main(int argc, char **argv)
{
  /*
   * 8-byte aligned so the int starts a doubleword. set_watchpoint() rounds
   * the watchpoint base address down to that boundary.
   */
  static int number __attribute__((aligned(8)));
  const char *bad_arg = NULL;
  struct sigaction act;
  bool result;

  if (argc == 2) {
    long it = parse_iteration(argv[1]);

    if (it > 0)
      _iteration = it;
    else
      bad_arg = argv[1];
  }

  ksft_print_header();
  ksft_set_plan(1);

  /* Reported only after the TAP header so the output stays well-formed. */
  if (bad_arg)
    ksft_print_msg("Ignoring invalid iteration count '%s', using %ld\n",
      bad_arg, _iteration);

  memset(&act, 0, sizeof(act));
  act.sa_handler = sigalrm;
  sigemptyset(&act.sa_mask);
  /*
   * No SA_RESTART: the alarm must interrupt waitpid() in run_test(), so a
   * missing watchpoint trap fails the test instead of hanging it.
   */
  act.sa_flags = 0;
  if (sigaction(SIGALRM, &act, NULL) != 0) {
    /* Without the handler, SIGALRM's default action would kill us. */
    ksft_print_msg("sigaction(SIGALRM) failed: %s\n", strerror(errno));
    ksft_exit_fail();
  }

  result = run_test(&number);
  /* Pass a literal format instead of a pre-formatted, non-literal one. */
  if (result) {
    ksft_test_result_pass("Test watchpoint set at %p with size %zu\n",
      (void *)&number, sizeof(number));
    ksft_exit_pass();
  } else {
    ksft_test_result_fail("Test watchpoint set at %p with size %zu\n",
      (void *)&number, sizeof(number));
    ksft_exit_fail();
  }
}
